{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "\n",
    "# Alternative configuration (vetting model); uncomment it and comment out the\n",
    "# active assignments below to use it.\n",
    "# tces_file = '/mnt/tess/astronet/tces-vetting-v4-toi-train.csv'\n",
    "# file_pattern = '/mnt/tess/astronet/tfrecords-vetting-5-toi-train/*'\n",
    "# model_name = 'AstroCNNModelVetting'\n",
    "# config_name = 'vrevised'\n",
    "# labels = ['p', 'e', 'n']\n",
    "\n",
    "# Active configuration: TCE table, TFRecord file pattern, model/config names and label set.\n",
    "tces_file = '/mnt/tess/astronet/tces-v14-val.csv'\n",
    "file_pattern = '/mnt/tess/astronet/tfrecords-38-val/*'\n",
    "model_name = 'AstroCNNModel'\n",
    "config_name = 'final_alpha_1'\n",
    "labels = ['E', 'N', 'J', 'S', 'B']\n",
    "\n",
    "# Sorted so that every cell walking `filenames` sees the files in a reproducible order.\n",
    "filenames = sorted(tf.io.gfile.glob(file_pattern))\n",
    "# With no matching files the checks in later cells would pass without reading any data.\n",
    "assert filenames, f'No files match {file_pattern}'\n",
    "\n",
    "filenames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "tce_table = pd.read_csv(tces_file, header=0, low_memory=False)\n",
    "print(len(tce_table))\n",
    "tce_table.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# One dict per record (feature name -> first value), so every record contributes exactly one row.\n",
    "rows = []\n",
    "\n",
    "for filename in filenames:\n",
    "  tfr = tf.data.TFRecordDataset(filename)\n",
    "  num_records = 0\n",
    "  for record in tfr:\n",
    "    num_records += 1\n",
    "    ex = tf.train.Example.FromString(record.numpy())\n",
    "    row = {}\n",
    "    for k in ex.features.feature.keys():\n",
    "      f = ex.features.feature[k]\n",
    "      # Keep only the first value of the feature; a feature with no values is left out of the row.\n",
    "      if f.int64_list.value:\n",
    "        row[k] = f.int64_list.value[0]\n",
    "      elif f.float_list.value:\n",
    "        row[k] = f.float_list.value[0]\n",
    "      elif f.bytes_list.value:\n",
    "        row[k] = f.bytes_list.value[0].decode()\n",
    "    rows.append(row)\n",
    "  print(filename, num_records)\n",
    "\n",
    "# Pivot the rows into equal-length columns. Appending to a per-feature list directly would\n",
    "# shorten a column whenever a record lacks that feature, misaligning it with the other\n",
    "# columns; here such a record gets None in that column instead.\n",
    "feature_names = list(dict.fromkeys(k for row in rows for k in row))\n",
    "series = {k: [row.get(k) for row in rows] for k in feature_names}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "examples_table = pd.DataFrame.from_dict(series)\n",
    "\n",
    "pd.set_option('display.max_columns', None)\n",
    "# examples_table[['secondary_scale']].describe()\n",
    "examples_table.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Number of examples with a positive value in each disp_<label> column.\n",
    "counts = [int((examples_table['disp_{}'.format(l)] > 0).sum()) for l in labels]\n",
    "total = sum(counts)\n",
    "# plt.bar returns a container holding one rectangle per label, in the order of `labels`.\n",
    "ax = plt.bar(labels, counts)\n",
    "for b, count in zip(ax, counts):\n",
    "    x, y = b.get_xy()\n",
    "    # Report 0% instead of dividing by zero when no example has a positive label.\n",
    "    share = count / total if total else 0.0\n",
    "    # The offsets place the text slightly right of the bar's left edge and above its top.\n",
    "    plt.annotate(\n",
    "        '{} - {:.0%}'.format(count, share),\n",
    "        (x + 0.1, y + b.get_height() + 11))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "examples_table.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tce_table[tce_table.index == 8209]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Astro IDs that carry the first label in the TCE table but not in the TFRecords.\n",
    "first_label_col = f'disp_{labels[0]}'\n",
    "tce_ids = set(tce_table[tce_table[first_label_col] > 0]['Astro ID'].values)\n",
    "example_ids = set(examples_table[examples_table[first_label_col] > 0]['astro_id'].values)\n",
    "\n",
    "print('Label mismatches between TCE and tfrecords:')\n",
    "# np.array(<set>) wraps the whole set in a single 0-d object array, so convert the\n",
    "# difference to a sorted list first to get a real 1-d array of ids.\n",
    "np.array(sorted(tce_ids - example_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "from astronet import models\n",
    "from astronet.astro_cnn_model import input_ds\n",
    "\n",
    "config = models.get_model_config(model_name, config_name)\n",
    "\n",
    "# Two datasets over the same files, unshuffled and read once: `ds` without labels and\n",
    "# `labels_ds` with labels. They are walked in lockstep below.\n",
    "# NOTE(review): assumes build_dataset yields examples in the same order for both - confirm.\n",
    "ds = input_ds.build_dataset(\n",
    "      file_pattern=file_pattern,\n",
    "      input_config=config.inputs,\n",
    "      batch_size=1,\n",
    "      include_labels=False,\n",
    "      shuffle_filenames=False,\n",
    "      repeat=1,\n",
    "      include_identifiers=True)\n",
    "labels_ds = input_ds.build_dataset(\n",
    "      file_pattern=file_pattern,\n",
    "      input_config=config.inputs,\n",
    "      batch_size=1,\n",
    "      include_labels=True,\n",
    "      shuffle_filenames=False,\n",
    "      repeat=1,\n",
    "      include_identifiers=True)\n",
    "\n",
    "label_index = {k.lower(): i for i, k in enumerate(config.inputs.label_columns)}\n",
    "# Derived from `labels` so that switching the configuration needs no edit here.\n",
    "cols = [f'disp_{l}' for l in labels]\n",
    "\n",
    "\n",
    "def raise_if_nan(e):\n",
    "  \"\"\"Returns `e` unchanged; prints it and raises ValueError if it contains a NaN.\"\"\"\n",
    "  # tf.math.is_nan only accepts floating-point tensors, so other dtypes are passed through.\n",
    "  if not e.dtype.is_floating:\n",
    "    return e\n",
    "  if tf.reduce_any(tf.math.is_nan(e)):\n",
    "    tf.print(e, summarize=-1)\n",
    "    raise ValueError('data has NaNs.')\n",
    "  return e\n",
    "\n",
    "\n",
    "all_ids = []\n",
    "bad_labels = []\n",
    "for d, lab in zip(ds, labels_ds):\n",
    "  ex_id = d[1].numpy().item()\n",
    "  all_ids.append(ex_id)\n",
    "\n",
    "  # Sanity check that both datasets are positioned on the same example.\n",
    "  assert lab[0]['duration'] == d[0]['duration']\n",
    "  rec = tce_table[tce_table['Astro ID'] == ex_id]\n",
    "  if rec.empty:\n",
    "    # Without a TCE row there is nothing to compare against; report it like a label mismatch.\n",
    "    bad_labels.append(ex_id)\n",
    "    print('bad example (no row in TCE table): ', ex_id)\n",
    "    break\n",
    "  for c in cols:\n",
    "    # A label is bad when exactly one of (dataset label, TCE table value) is zero.\n",
    "    if (lab[1][0][label_index[c.lower()]].numpy() == 0) != (rec[c].values[0] == 0):\n",
    "      bad_labels.append(ex_id)\n",
    "      print('bad example: ', ex_id)\n",
    "      print(rec)\n",
    "      print(cols)\n",
    "      print(lab[1][0])\n",
    "      break\n",
    "  if bad_labels:\n",
    "    break\n",
    "\n",
    "  try:\n",
    "    tf.nest.map_structure(raise_if_nan, d)\n",
    "  except ValueError as e:\n",
    "    print(e)\n",
    "    print(d[1])\n",
    "    break\n",
    "else:\n",
    "  # Reached only when the loop finished without hitting any `break` above.\n",
    "  print('No NaNs or mismatched labels found.')\n",
    "\n",
    "# Single pass over the ids instead of calling list.count() once per distinct id.\n",
    "seen_ids = set()\n",
    "duplicate_ids = set()\n",
    "for seen_id in all_ids:\n",
    "  if seen_id in seen_ids:\n",
    "    duplicate_ids.add(seen_id)\n",
    "  seen_ids.add(seen_id)\n",
    "\n",
    "if not duplicate_ids:\n",
    "  print('No duplicates found.')\n",
    "else:\n",
    "  print('Found duplicates!', len(all_ids) - len(seen_ids))\n",
    "  print(sorted(duplicate_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def astro_id(tic_id):\n",
    "  \"\"\"Returns the 'Astro ID' of the first `tce_table` row whose 'TIC ID' equals `tic_id`.\n",
    "\n",
    "  Raises:\n",
    "    IndexError: if no row of `tce_table` has that TIC ID.\n",
    "  \"\"\"\n",
    "  matches = tce_table.loc[tce_table['TIC ID'] == tic_id, 'Astro ID'].values\n",
    "  if len(matches) == 0:\n",
    "    # Same exception type as indexing the empty result, but with a message naming the id.\n",
    "    raise IndexError(f'TIC ID {tic_id} not found in tce_table')\n",
    "  return matches[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from astronet.preprocess import preprocess\n",
    "\n",
    "\n",
    "tess_data_dir = '/mnt/tess/lc'\n",
    "\n",
    "def find_tce(astro_id):\n",
    "  \"\"\"Returns the first tf.train.Example in `filenames` whose 'astro_id' feature equals `astro_id`.\n",
    "\n",
    "  Before returning, prints the TIC ID that `tce_table` records for this Astro ID,\n",
    "  the example's disp_<label> values and its 'Duration' feature.\n",
    "\n",
    "  Raises:\n",
    "    ValueError: if no record in any of `filenames` matches.\n",
    "  \"\"\"\n",
    "  # Look the TIC ID up from the argument. The notebook-level `tic_id` variable may belong\n",
    "  # to a different TCE, or not be defined at all, when this function is called.\n",
    "  tic_ids = tce_table.loc[tce_table['Astro ID'] == astro_id, 'TIC ID'].values\n",
    "  with tf.device('cpu'):\n",
    "    for filename in filenames:\n",
    "      for record in tf.data.TFRecordDataset(filename):\n",
    "        ex = tf.train.Example.FromString(record.numpy())\n",
    "        features = ex.features.feature\n",
    "        if features[\"astro_id\"].int64_list.value[0] != astro_id:\n",
    "          continue\n",
    "        print('TIC ID:', tic_ids[0] if len(tic_ids) else 'unknown')\n",
    "        for l in labels:\n",
    "          print(f'{l}:', features[f\"disp_{l}\"].int64_list.value[0])\n",
    "        # NOTE(review): this key is capitalised while the dataset feature used elsewhere in\n",
    "        # the notebook is 'duration' - confirm the raw TFRecord key.\n",
    "        print('Duration:', features[\"Duration\"].float_list.value[0])\n",
    "        return ex\n",
    "\n",
    "    raise ValueError(\"{} not found in files: {}\".format(astro_id, filenames))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: re-running this cell stacks another cache() on top of `ds`.\n",
    "ds = ds.cache()\n",
    "def plot_ds_tce(ds, astro_id):\n",
    "    \"\"\"Prints the non-array features and plots six views of the element matching `astro_id`.\n",
    "\n",
    "    Iterates `ds`, whose elements are indexed as (features dict, identifier), and stops at\n",
    "    the first element whose identifier equals `astro_id`. Returns None in every case; prints\n",
    "    a message when no element matches.\n",
    "    \"\"\"\n",
    "    # Features under these prefixes are skipped by the printout.\n",
    "    skipped_prefixes = ('local_', 'global_', 'secondary_', 'sample_')\n",
    "    # Feature plotted in each cell of the 2x3 grid.\n",
    "    panels = (\n",
    "        ('global_view', 'local_view', 'secondary_view'),\n",
    "        ('global_mask', 'global_view_0.3', 'global_view_5.0'),\n",
    "    )\n",
    "    for d in ds:\n",
    "        if d[1] == astro_id:\n",
    "            for k, v in d[0].items():\n",
    "                if k.startswith(skipped_prefixes):\n",
    "                    continue\n",
    "                print(f'{k:25}: {v.numpy()}')\n",
    "            fig, axes = plt.subplots(2, 3, figsize=(20, 12))\n",
    "            for row, keys in enumerate(panels):\n",
    "                for col, key in enumerate(keys):\n",
    "                    # [0] takes the first entry along the leading axis of the feature.\n",
    "                    axes[row, col].plot(d[0][key][0].numpy(), '.-')\n",
    "                    axes[row, col].set_title(key)\n",
    "            plt.show()\n",
    "            plt.close('all')\n",
    "            return\n",
    "    print(f'{astro_id} not found in dataset')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tic_id = 349412074\n",
    "plot_ds_tce(ds, astro_id(tic_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tce_table[tce_table['TIC ID'] == tic_id]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "astro_id(tic_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "examples_table[examples_table['astro_id'] == astro_id(tic_id)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls /mnt/tess/lc-v | grep 237320326"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "tic_id = 334227600\n",
    "tce = find_tce(astro_id(tic_id))\n",
    "\n",
    "list(tce.features.feature.keys())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
